import os
import pathlib
import wandb
import torch
import os
import sys
import numpy as np
sys.path.insert(0, os.getcwd())
import tools.utils as utils
from vis.img_plt import plot_mix_img, save_ori_img, save_mix_i_img
from datatool.datatool import get_dl_tr
from tools.args import get_general_args
import random
import shutil


def proj(args, x):
    """Project a batch of images as selected by ``args.inex_proj``.

    Supported modes:

    * ``'clip'`` -- clamp every value into the module-level ``input_range``.
    * ``'1'`` / ``'2'`` -- divide each sample by its L1 / L2 norm, taken over
      all non-batch dimensions.  Not implemented when the module-level ``dt``
      is ``'mnist'``.
    * ``'norm'`` -- min-max rescale each sample to ``[0, 1]``.
    * anything else -- return ``x`` itself, untouched.

    Args:
        args: namespace providing ``inex_proj``.
        x: tensor whose first dimension indexes samples.  Any number of
            trailing dimensions is accepted and the tensor does not need to
            be contiguous.

    Returns:
        A tensor with the same shape as ``x``.

    Raises:
        NotImplementedError: for the ``'1'``/``'2'`` modes on ``'mnist'``.
    """
    mode = args.inex_proj

    if mode == 'clip':
        return torch.clamp(x, *input_range)

    if mode not in ('1', '2', 'norm'):
        return x

    # Shape that broadcasts one scalar per sample against ``x`` whatever its
    # rank is (the previous ``[:, None, None, None]`` was only right for 4-D).
    per_sample = (x.shape[0],) + (1,) * (x.dim() - 1)
    # ``reshape`` instead of ``view`` so non-contiguous inputs are accepted.
    flat = x.reshape(x.shape[0], -1)

    if mode == 'norm':
        lo = flat.min(dim=1)[0]
        span = flat.max(dim=1)[0] - lo
        # A constant sample has zero range; divide by 1 so it maps to 0
        # instead of NaN.
        span = torch.where(span == 0, torch.ones_like(span), span)
        return (x - lo.reshape(per_sample)) / span.reshape(per_sample)

    # mode is '1' or '2': p-norm projection.
    if dt == 'mnist':
        # Would need a [0, 1] -> [-1, 1] transform (2 * (x - 0.5)) before the
        # projection and its inverse afterwards; neither exists yet.
        raise NotImplementedError
    p = 1 if mode == '1' else 2
    nm = torch.norm(flat, p=p, dim=1)
    # An all-zero sample has zero norm; leave it as zeros instead of NaN.
    nm = torch.where(nm == 0, torch.ones_like(nm), nm)
    return x / nm.reshape(per_sample)


# ---------------------------------------------------------------------------
# Script setup: parse the arguments, create the output / wandb directories,
# start the wandb run, build the data loader and the coefficient tensors that
# the functions in this file read as module-level globals.
# ---------------------------------------------------------------------------
args = get_general_args()
legend = utils.get_legend()

# Everything is written below <output_dir>/<proj_name>/<exp>.
output_path = os.path.join(args.output_dir, args.proj_name, args.exp)
pathlib.Path(output_path).mkdir(parents=True, exist_ok=True)
wandb_path = os.path.join(args.wandb_dir, args.proj_name, args.exp)
pathlib.Path(wandb_path).mkdir(parents=True, exist_ok=True)
wandb.init(project="inex_mix", name=legend, tags=args.tags,
           dir=wandb_path, entity=args.wandb_entity)
wandb.config.update(args, allow_val_change=True)

# Loader is requested with shuffle=False and no_aug=True.
vl_dl = get_dl_tr(args, shuffle=False, no_aug=True)
n_cls = args.num_labels

n_pts = 19
# Number of samples per class in one batch (integer division of the batch
# size by the number of classes).
nc = args.bsz // n_cls

if args.inex_linsp:
    # Evenly spaced coefficients from -4 to 5 (step 0.5 for 19 points).
    lamb = torch.linspace(-4, 5, n_pts)[:, None, None, None]
else:
    # Quadratically spaced coefficients.
    # it must have: -0.5, 0, 0.5, 1, 1.5
    # sq_sp = k**2 / 8 for k = 0..7, and the concatenation below gives
    #   -sq_sp reversed : 7 values below 0, then 0 itself
    #   0.125, 0.5, 0.875: 3 values inside (0, 1)
    #   1 + sq_sp       : 1 itself, then 7 values above 1
    arn = torch.linspace(0, 7, 8)
    pw = torch.pow(arn, 2)
    sq_sp = pw / (pw[2] * 2)
    lamb = torch.cat([-sq_sp.flip(0), sq_sp[1:3], 1 - sq_sp[1:2],
                      (sq_sp + 1)])[:, None, None, None]
    n_pts = len(lamb) #19

# Noise scaling factors, one per coefficient, evenly spaced in [-1, 1].
alpha = torch.linspace(-1, 1, n_pts)[:, None, None, None]
dt = args.dataset
# Value range used for clipping and for sampling uniform noise.
input_range = (0, 1) if dt == 'mnist' else (-1, 1)
print(output_path)

def mixing_all_coef():
    """Save original images, then mixed images for every coefficient.

    Iterates over the module-level loader ``vl_dl``:

    * batches with index below ``args.inex_nphase1`` are passed to
      ``save_ori_img`` (first ``n_pts`` samples of every class);
    * later batches, up to but excluding index ``args.inex_nbreak``, are
      split per class into a first half and a second half.  The second half
      is shuffled within each class and then re-laid-out across classes, and
      every first-half image is mixed with its partner for all coefficients
      in ``lamb`` at once and passed to ``plot_mix_img``.

    The file names returned by the helpers are collected and written, one per
    line, to ``flst.txt`` inside ``output_path``.

    Reads the module-level ``args``, ``vl_dl``, ``n_cls``, ``nc``, ``n_pts``,
    ``lamb``, ``alpha``, ``dt``, ``input_range`` and ``output_path``.

    Raises:
        RuntimeError: if the running image counter disagrees with the index
            derived from the loop position (internal sanity check).
    """
    flst = list()
    i_count = 0
    nc_half = nc // 2
    # yf[i, j] is the class of the partner image xf_t[i, j] built below: the
    # same (class-major -> sample-major) re-layout applied to class indices.
    yf = torch.arange(n_cls).repeat(nc_half, 1).contiguous().reshape(n_cls, nc_half)

    for i_iter, (x, y) in enumerate(vl_dl):
        if i_iter < args.inex_nphase1:
            # Phase 1: only dump the original images.
            # Stacking requires every class to have the same number of
            # samples in the batch.
            xbyy = torch.stack([x[y == c] for c in range(n_cls)], dim=0)
            xbyy = xbyy[:, :n_pts]
            flst.extend(save_ori_img(args, dt, output_path, xbyy, n_pts,
                                     n_cls, i_iter))
            continue
        if i_iter >= args.inex_nbreak:
            break

        # Phase 2: mix pairs of images.
        _flst = list()
        xbyy = torch.stack([x[y == c] for c in range(n_cls)], dim=0)
        xm = xbyy[:, :nc_half]
        xf = xbyy[:, nc_half:]
        # Independent random permutation of the second half for each class.
        f_inds = torch.argsort(torch.rand(*xf.shape[:2]), dim=-1)
        xf_shuffled = xf[torch.arange(xf.shape[0]).unsqueeze(-1), f_inds]
        # Re-layout so that consecutive entries cycle through the classes.
        xf_t = xf_shuffled.transpose(0, 1).contiguous().reshape(*xf.shape)

        for i in range(n_cls):
            for j in range(nc_half):
                count = ((i_iter - args.inex_nphase1) * n_cls * nc_half
                         + i * nc_half + j)
                if i_count != count:
                    # Previously dropped into an interactive debugger here.
                    raise RuntimeError(
                        'image counter out of sync: i_count={} but expected '
                        '{}'.format(i_count, count))
                xi = xm[i, j].unsqueeze(0)
                xj = xf_t[i, j].unsqueeze(0)
                yi = i
                yj = yf[i, j]
                if args.inex_pre_proj:
                    xi = proj(args, xi)
                    xj = proj(args, xj)
                # lamb has shape (n_pts, 1, 1, 1), so this evaluates every
                # coefficient in one broadcasted expression.
                x_mix = (1 - lamb) * xi + lamb * xj.unsqueeze(0)
                if args.inex_noise:
                    noise = torch.FloatTensor(xi.shape).uniform_(*input_range)
                    noise_mix = alpha * noise
                    x_n = x_mix.unsqueeze(0) + noise_mix.unsqueeze(1)
                else:
                    # Drop the leading singleton added by the extra unsqueeze.
                    x_n = x_mix.squeeze(0)
                x_proj = proj(args, x_n) if args.inex_post_proj else x_n
                flst_sub = plot_mix_img(args, dt, output_path, x_proj,
                                        yi, yj, n_pts, i_count)
                _flst.extend(flst_sub)
                i_count += 1
        flst.extend(_flst)

    with open(os.path.join(output_path, 'flst.txt'), 'w') as fp:
        for item in flst:
            # write each item on a new line
            fp.write("%s\n" % item)
        print('Done')


def mixing_coef_with_cnt(lamb, n_keep=250):
    """Write per-batch image sets, using one mixing coefficient per pair.

    For every batch of the module-level loader ``vl_dl``:

    1. the first ``nc // 9`` samples of each class are saved unchanged with
       ``save_ori_img``;
    2. the next ``4 * (nc // 9)`` samples of each class are each mixed with
       one partner image taken from the remaining samples (shuffled within
       each class, then re-laid-out across classes).  The coefficient cycles
       through ``lamb`` with two entries removed, and the result is saved
       with ``save_mix_i_img``;
    3. the first ``n_keep`` produced files are dealt alternately into two
       directories ``set_{2*i_iter}`` and ``set_{2*i_iter + 1}`` under
       ``output_path``, each with a shuffled ``flst_{k}.txt`` listing, and
       the temporary ``set_tmp_{i_iter}`` directory is removed.

    Reads the module-level ``args``, ``vl_dl``, ``n_cls``, ``nc``, ``alpha``,
    ``dt``, ``input_range`` and ``output_path``.

    Args:
        lamb: tensor of mixing coefficients indexed along its first
            dimension (the script passes the module-level ``lamb``).
        n_keep: number of files per batch that are distributed into the two
            sets; files beyond that are deleted with the temporary
            directory.  The default matches the previously hard-coded value.

    Raises:
        IndexError: if a batch produced fewer than ``n_keep`` files.
    """
    nc_ori = nc // 9
    nc_half = nc_ori * 4
    # Drops entries 8 and 10 of ``lamb``.
    # NOTE(review): with the linspace schedule those are the coefficients 0
    # and 1 (the unmixed originals); with the quadratic schedule 0 and 1 sit
    # at indices 7 and 11, so 0.125 and 0.875 are dropped instead -- confirm
    # which is intended.
    lamb_wo_ori = torch.cat((lamb[:8], lamb[9:10], lamb[11:]), dim=0)
    n_coef = lamb_wo_ori.shape[0]

    for i_iter, (x, y) in enumerate(vl_dl):
        set_path = os.path.join(output_path, 'set_tmp_{}'.format(i_iter))
        pathlib.Path(set_path).mkdir(parents=True, exist_ok=True)

        # Stacking requires every class to have the same number of samples
        # in the batch.
        xbyy = torch.stack([x[y == c] for c in range(n_cls)], dim=0)
        xo = xbyy[:, :nc_ori]
        flst = save_ori_img(args, dt, set_path, xo, nc_ori, n_cls, i_iter)

        xm = xbyy[:, nc_ori:nc_ori * 5]
        xf = xbyy[:, nc_ori * 5:]
        n_far = xf.shape[1]
        # Independent random permutation of the partner pool for each class.
        f_inds = torch.argsort(torch.rand(*xf.shape[:2]), dim=-1)
        xf_shuffled = xf[torch.arange(xf.shape[0]).unsqueeze(-1), f_inds]
        # Re-layout so that consecutive entries cycle through the classes.
        xf_t = xf_shuffled.transpose(0, 1).contiguous().reshape(*xf.shape)

        i_count = 0
        for i in range(n_cls):
            for j in range(nc_half):
                xi = xm[i, j].unsqueeze(0)
                xj = xf_t[i, j].unsqueeze(0)
                yi = i
                # After the re-layout above, entry [i, j] sits at flat
                # position i * n_far + j and classes repeat every n_cls
                # entries.  Equals the former ``j % 10`` when n_cls == 10
                # and n_far is a multiple of 10.
                yj = (i * n_far + j) % n_cls
                i_lamb = i_count % n_coef
                lam = lamb_wo_ori[i_lamb]
                x_mix = (1 - lam) * xi + lam * xj.unsqueeze(0)
                if args.inex_noise:
                    noise = torch.FloatTensor(xi.shape).uniform_(*input_range)
                    noise_mix = alpha * noise
                    x_n = x_mix.unsqueeze(0) + noise_mix.unsqueeze(1)
                else:
                    # Drop the leading singleton added by the extra unsqueeze.
                    x_n = x_mix.squeeze(0)
                x_proj = proj(args, x_n) if args.inex_post_proj else x_n
                flst_sub = save_mix_i_img(args, dt, set_path, x_proj,
                                          yi, yj, i_lamb, i_iter, i_count)
                flst.extend(flst_sub)
                i_count += 1

        # Fail before touching any file rather than half-way through a move.
        if len(flst) < n_keep:
            raise IndexError(
                'batch {} produced {} files but n_keep={} are required'
                .format(i_iter, len(flst), n_keep))
        kept = flst[:n_keep]

        # Even positions go to the first set, odd positions to the second.
        for i in range(2):
            i_set = i_iter * 2 + i
            sub_path = os.path.join(output_path, 'set_{}'.format(i_set))
            pathlib.Path(sub_path).mkdir(parents=True, exist_ok=True)
            flst_sub = kept[i::2]
            for fname in flst_sub:
                shutil.move(os.path.join(set_path, fname),
                            os.path.join(sub_path, fname))
            # The listing is shuffled; the files themselves are not renamed.
            random.shuffle(flst_sub)
            with open(os.path.join(sub_path, 'flst_{}.txt'.format(i_set)),
                      'w') as fp:
                for item in flst_sub:
                    # write each item on a new line
                    fp.write("%s\n" % item)
                print(str(i_set) + ' done')
        shutil.rmtree(set_path)

# Script entry: runs at module execution time with the module-level
# coefficient tensor `lamb`.
mixing_coef_with_cnt(lamb)
